import pandas as pd
import numpy as np
import torch
from transformers import pipeline, logging
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from sklearn.utils import resample
from datasets import load_dataset, DatasetDict, Dataset
from tqdm import tqdm
import random
import warnings
# Silence the two warning categories this script does not act on.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
# pandas: turn on the future no-silent-downcasting behaviour and switch off
# the chained-assignment check.
pd.options.future.no_silent_downcasting = True
pd.set_option('mode.chained_assignment', None)
# transformers: only report errors.
logging.set_verbosity_error()

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device():
    """
    Resolve the name of the compute device requested for this run.

    Priority: ``--device`` CLI argument > ``DEVICE`` environment variable >
    built-in default of ``"cuda"``.

    Returns:
        str: The requested device name with surrounding whitespace removed and
        lower-cased (e.g. "cpu", "mps", "cuda"). The name is not validated here.
    """
    # Allow --device on the command line
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    args, _ = parser.parse_known_args()  # ignore unknown args so script still works normally
    # A blank value (e.g. `DEVICE=` exported by a wrapper script) counts as
    # "not set" so it falls through to the next source instead of being
    # returned as an empty device name.
    for candidate in (args.device, os.getenv("DEVICE")):
        if candidate and candidate.strip():
            return candidate.strip().lower()
    return "cuda"
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")

# PyTorch setup
import torch

# Reject anything that is not one of the three supported names.
if DEVICE not in ("cpu", "mps", "cuda"):
    raise ValueError(f"Unknown device {DEVICE}")

# Only the backend that was asked for is probed; "cpu" probes nothing.
accelerator_ready = (
    (DEVICE == "mps" and torch.backends.mps.is_available())
    or (DEVICE == "cuda" and torch.cuda.is_available())
)
# Use the requested accelerator when its probe succeeds, otherwise the CPU.
device = torch.device(DEVICE if accelerator_ready else "cpu")
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################

def label_docs(model, docs_dict, batch_size = 8, device = device):
    """
    Classify documents with a text-classification pipeline.

    Args:
        model: Model name or path handed to ``pipeline``.
        docs_dict: Inputs handed to the pipeline; the callers in this file
            pass a list of {'text': ..., 'text_pair': ...} dicts.
        batch_size: Batch size handed to the pipeline.
        device: Device handed to the pipeline. The default is the
            module-level ``device`` as it was when this function was defined.

    Returns:
        A list with the 'label' entry of each pipeline result, in the order
        the pipeline produced them.
    """
    pipeline_options = {
        'task': 'text-classification',
        'model': model,
        'batch_size': batch_size,
        'device': device,
        # truncation settings forwarded to the pipeline
        'max_length': 512,
        'truncation': True,
        'torch_dtype': torch.bfloat16,
    }
    classifier = pipeline(**pipeline_options)

    predicted_labels = []
    for prediction in classifier(docs_dict):
        predicted_labels.append(prediction['label'])
    return predicted_labels

# Each model under test, paired with the results column its labels go into.
_MODEL_COLUMN_PAIRS = (
    ("MoritzLaurer/deberta-v3-base-zeroshot-v2.0", 'base_nli'),
    ("MoritzLaurer/deberta-v3-large-zeroshot-v2.0", 'large_nli'),
    ("mlburnham/Political_DEBATE_DeBERTa_base_v1.1", 'base_debate'),
    ("mlburnham/Political_DEBATE_DeBERTa_large_v1.1", 'large_debate'),
    ("mlburnham/Political_DEBATE_ModernBERT_base_v1.0", 'base_modern'),
    ("mlburnham/Political_DEBATE_ModernBERT_large_v1.0", 'large_modern'),
)

# models that will be tested
models = [checkpoint for checkpoint, _ in _MODEL_COLUMN_PAIRS]

# column names that will hold results (same order as `models`)
columns = [column for _, column in _MODEL_COLUMN_PAIRS]

################
## UKP Stance
################
ukp = pd.read_csv('./data/ukp_stance.csv')

# One pipeline input per row: the document as 'text', its hypothesis as 'text_pair'.
docs_dict = [
    {'text': text, 'text_pair': hypothesis}
    for text, hypothesis in zip(ukp['text'], ukp['hypothesis'])
]

# Numeric codes written to the results columns in place of the label strings.
label_codes = {'entailment': 0, 'not_entailment': 1}

# Label the documents with every model and store the coded labels in that
# model's column of the dataframe.
for modname, col in zip(models, columns):
    ukp[col] = label_docs(modname, docs_dict, batch_size=8, device=device)
    ukp[col] = ukp[col].replace(label_codes)
    print(f'{modname} complete.')

# Results are written back to the file the data was read from.
ukp.to_csv('./data/ukp_stance.csv', index=False)

################
## UKP Topic
################
topic = pd.read_csv('./data/ukp_topic.csv')

# One pipeline input per row: the document as 'text', its hypothesis as 'text_pair'.
docs_dict = [
    {'text': text, 'text_pair': hypothesis}
    for text, hypothesis in zip(topic['text'], topic['hypothesis'])
]

# Numeric codes written to the results columns in place of the label strings.
label_codes = {'entailment': 0, 'not_entailment': 1}

# Label the documents with every model and store the coded labels in that
# model's column of the dataframe.
for modname, col in zip(models, columns):
    topic[col] = label_docs(modname, docs_dict, batch_size=8, device=device)
    topic[col] = topic[col].replace(label_codes)
    print(f'{modname} complete.')

# Results are written back to the file the data was read from.
topic.to_csv('./data/ukp_topic.csv', index=False)

################
## RAND Event
################
rand = pd.read_csv('./data/rand_terror.csv')

# The candidate labels are the distinct hypothesis sentences in the data.
# Neither they nor the premises depend on the model, so build both once.
labels = list(rand['hypothesis'].unique())
premises = list(rand['premise'])

# for each model, run zero-shot classification over the premises and keep the
# first entry of each result's 'labels' list (the pipeline returns labels
# ordered from most to least likely).
for modname, col in zip(models, columns):
    pipe = pipeline(task = 'zero-shot-classification', model = modname,
                    batch_size = 16, device = device,
                    max_length = 512, truncation = True, torch_dtype = torch.bfloat16)

    # The labels are already complete sentences, so they must reach the model
    # unchanged. The pipeline's keyword for that is `hypothesis_template`; any
    # other keyword (such as `template`) is ignored and the default
    # "This example is {}." gets wrapped around every label.
    res = pipe(premises, candidate_labels = labels, hypothesis_template = '{}')
    rand[col] = [result['labels'][0] for result in res]
    print(modname + ' complete.')

# Numeric code stored for each hypothesis sentence.
attack_codes = {'This text describes an explosives attack.': 1,
                'This text describes a firearms attack.': 2,
                'This text describes an arson attack.': 3,
                'This text describes a knife or sharp object attack.': 4,
                'This text describes a biological agent attack.': 5,
                'This text describes a chemical agent attack.': 6}

for col in columns:
    # Assign the result back instead of calling replace(..., inplace = True)
    # on the selected column: the in-place form acts on a temporary object and
    # does not update the dataframe under pandas Copy-on-Write.
    rand[col] = rand[col].replace(attack_codes)

# Results are written back to the file the data was read from.
rand.to_csv('./data/rand_terror.csv', index = False)
########################
## Deliberative Politics
########################
dp = pd.read_csv('./data/deliberative_politics.csv')

# One pipeline input per row: the document as 'text', its hypothesis as 'text_pair'.
docs_dict = [
    {'text': text, 'text_pair': hypothesis}
    for text, hypothesis in zip(dp['text'], dp['hypothesis'])
]

# Numeric codes written to the results columns in place of the label strings.
label_codes = {'entailment': 0, 'not_entailment': 1}

# Label the documents with every model and store the coded labels in that
# model's column of the dataframe.
for modname, col in zip(models, columns):
    dp[col] = label_docs(modname, docs_dict, batch_size=8, device=device)
    dp[col] = dp[col].replace(label_codes)
    print(f'{modname} complete.')

# Results are written back to the file the data was read from.
dp.to_csv('./data/deliberative_politics.csv', index=False)